from keras.models import Sequential
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.layers.core import Dense, Dropout, Activation, Flatten
import h5py
from keras.layers.normalization import BatchNormalization
# from seya.layers.attention import SpatialTransformer, ST2
import numpy as np
from utils.extra import Softmax4D


def VGG_16(config=None, heat_map=False):
    """Build a VGG-16 style network with BatchNormalization after every conv.

    Each of the thirteen 3x3 convolutions is laid out as
    ZeroPadding2D -> linear Convolution2D -> BatchNormalization -> ReLU, and
    each of the five conv blocks ends in a 2x2/stride-2 MaxPooling2D.

    Args:
        config: dict of settings. Despite the ``None`` default it is indexed
            immediately, so a dict is required. Keys read: 'to_gray',
            'trainable_conv' and 'trainable_fc' (each compared with the string
            'True'); in classification mode also 'weights_path' and
            'classifier'.
        heat_map: if True, accept inputs of any spatial size and finish with a
            convolutional head (7x7 conv, 1x1 conv, 1x1 conv to 10 maps,
            ``Softmax4D`` over axis 1). In this mode no weights are loaded.
            If False, the input is (channels, 224, 224) and the head is
            Flatten + two 4096-unit Dense layers (+BN) + a 10-unit Dense.

    Returns:
        A ``Sequential`` model; it is not compiled here.

    Weight loading (classification mode, non-empty ``config['weights_path']``):
    the file at '/media/dell/delldisk/dell/ExData/' + weights_path is read as
    groups ``layer_<i>`` holding datasets ``param_<j>`` and an ``nb_params``
    attribute. Model layers whose names start with 'd_', 'bn' or 'relu' are
    skipped, so they do not consume a ``layer_<i>`` group. Loading happens
    before the final 'dense_3' layer is added, so that layer keeps its fresh
    initialisation.
    """
    img_chnnel = 1 if config['to_gray'] == 'True' else 3
    # NOTE(review): every conv layer below is gated by trainable_fc, so
    # config['trainable_conv'] currently has no effect on this model
    # (VGG_16_Ori gates its first two conv blocks with trainable_conv).
    # Kept unchanged to avoid silently altering training runs -- confirm intent.
    trainable_conv = config['trainable_conv'] == 'True'
    trainable_fc = config['trainable_fc'] == 'True'

    if heat_map:
        input_shape = (img_chnnel, None, None)
    else:
        input_shape = (img_chnnel, 224, 224)

    model = Sequential()
    # (number of filters, number of 3x3 convs) for each of the five blocks.
    conv_blocks = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
    for block, (nb_filter, nb_conv) in enumerate(conv_blocks, 1):
        for conv in range(1, nb_conv + 1):
            suffix = '{}_{}'.format(block, conv)
            if block == 1 and conv == 1:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1, 1), input_shape=input_shape))
            else:
                model.add(ZeroPadding2D((1, 1)))
            # Linear conv so BatchNormalization runs before the ReLU.
            model.add(Convolution2D(nb_filter, 3, 3, activation='linear',
                                    trainable=trainable_fc, name='conv' + suffix))
            model.add(BatchNormalization(name='bn' + suffix))
            model.add(Activation('relu', name='relu' + suffix))
        model.add(MaxPooling2D((2, 2), strides=(2, 2)))

    if heat_map:
        model.add(Convolution2D(4096, 7, 7, activation="relu", name="dense_1"))
        model.add(Convolution2D(4096, 1, 1, activation="relu", name="dense_2"))
        model.add(Convolution2D(10, 1, 1, name="dense_3"))
        model.add(Softmax4D(axis=1, name="softmax"))
        return model

    model.add(Flatten())
    model.add(Dense(4096, activation='relu', trainable=trainable_fc, name="dense_1"))
    model.add(BatchNormalization(name='bn6_1'))
    model.add(Dropout(0.0))
    model.add(Dense(4096, activation='relu', trainable=trainable_fc, name="dense_2"))
    model.add(BatchNormalization(name='bn6_2'))
    model.add(Dropout(0.0))

    if config['weights_path']:
        weights_path = '/media/dell/delldisk/dell/ExData/' + config['weights_path']
        # Layers added by this model that have no counterpart in the savefile.
        skip_prefixes = ('d_', 'bn', 'relu')
        # Open read-only: older h5py defaults to mode 'a', which opens for
        # writing and creates an empty file when the path is wrong. The
        # context manager also closes the file if set_weights raises.
        with h5py.File(weights_path, 'r') as f:
            file_idx = 0  # index of the next 'layer_<i>' group in the file
            for layer in model.layers:
                print(layer.name)
                if layer.name.startswith(skip_prefixes):
                    continue
                g = f['layer_{}'.format(file_idx)]
                weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
                layer.set_weights(weights)
                file_idx += 1

    model.add(Dense(10, activation=config['classifier'], name="dense_3"))
    return model


def VGG_16_Ori(config=None, heat_map=False, extract_feats=False):
    """Build a plain VGG-16 conv base (no BatchNormalization) plus a 10-way head.

    Each of the thirteen 3x3 convolutions is preceded by a ZeroPadding2D and
    uses a ReLU activation; each of the five blocks ends in a 2x2/stride-2
    MaxPooling2D.

    Args:
        config: dict of settings. Despite the ``None`` default it is indexed
            immediately, so a dict is required. Keys read: 'to_gray',
            'trainable_conv' and 'trainable_fc' (each compared with the string
            'True') and 'weights_path'; when ``heat_map`` is False also
            'resize' and 'classifier'.
        heat_map: if True, accept inputs of any spatial size and finish with a
            convolutional head (7x7 conv, 1x1 conv, 1x1 conv to 10 maps,
            ``Softmax4D`` over axis 1) instead of Flatten + Dense layers.
        extract_feats: accepted for interface compatibility; not used here.

    Returns:
        A ``Sequential`` model; it is not compiled here.

    If ``config['weights_path']`` is non-empty, weights are loaded into the
    conv base only (the head is added afterwards). The file is read as groups
    ``layer_<i>`` holding datasets ``param_<j>`` and an ``nb_params``
    attribute. Layers named 'dr_*' are skipped.
    """
    img_chnnel = 1 if config['to_gray'] == 'True' else 3
    trainable_conv = config['trainable_conv'] == 'True'
    trainable_fc = config['trainable_fc'] == 'True'

    if heat_map:
        input_shape = (img_chnnel, None, None)
    else:
        # NOTE(review): resize is indexed [1] then [0] -- presumably
        # (width, height) with a (channels, rows, cols) input; confirm.
        input_shape = (img_chnnel, config['resize'][1], config['resize'][0])

    # (filters, number of 3x3 convs, trainable flag) per block. Blocks 1-2
    # follow trainable_conv and blocks 3-5 follow trainable_fc -- kept as-is;
    # confirm this split is intended.
    conv_blocks = (
        (64, 2, trainable_conv),
        (128, 2, trainable_conv),
        (256, 3, trainable_fc),
        (512, 3, trainable_fc),
        (512, 3, trainable_fc),
    )
    model = Sequential()
    for block, (nb_filter, nb_conv, trainable) in enumerate(conv_blocks, 1):
        for conv in range(1, nb_conv + 1):
            if block == 1 and conv == 1:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1, 1), input_shape=input_shape))
            else:
                model.add(ZeroPadding2D((1, 1)))
            model.add(Convolution2D(nb_filter, 3, 3, activation='relu', trainable=trainable,
                                    name='conv{}_{}'.format(block, conv)))
        model.add(MaxPooling2D((2, 2), strides=(2, 2)))

    if config['weights_path']:
        weights_path = ('/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/ExData/'
                        + config['weights_path'])
        # Open read-only: older h5py defaults to mode 'a', which opens for
        # writing and creates an empty file when the path is wrong. The
        # context manager also closes the file if set_weights raises.
        with h5py.File(weights_path, 'r') as f:
            file_idx = 0  # index of the next 'layer_<i>' group in the file
            for layer in model.layers:
                if layer.name.startswith('dr_'):
                    continue
                g = f['layer_{}'.format(file_idx)]
                weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
                layer.set_weights(weights)
                file_idx += 1

    if heat_map:
        model.add(Convolution2D(4096, 7, 7, activation="relu", name="dense_1"))
        model.add(Convolution2D(4096, 1, 1, activation="relu", name="dense_2"))
        model.add(Convolution2D(10, 1, 1, name="dense_3"))
        model.add(Softmax4D(axis=1, name="softmax"))
    else:
        model.add(Flatten())
        model.add(Dense(2048, activation='relu', trainable=trainable_fc, name="dense_1"))
        model.add(Dropout(0.5))
        model.add(Dense(1024, activation='relu', trainable=trainable_fc, name="dense_2"))
        model.add(Dropout(0.5))
        model.add(Dense(10, activation=config['classifier'], trainable=trainable_fc, name="dense_3"))
    return model


def keras_template(config):
    """Build a small CNN: two conv stages (32 then 64 filters) and a 512-unit head.

    Each stage is Conv(same) -> BN -> ReLU -> Conv(valid) -> BN -> ReLU ->
    2x2 MaxPooling -> Dropout(0.25). The head is Flatten -> Dense(512) ->
    ReLU -> BN -> Dropout(0.5) -> Dense(10) -> Activation(config['classifier']).

    Args:
        config: dict with 'to_gray' (string 'True' selects one input channel,
            otherwise three), 'resize' (indexed [1] then [0] for the input's
            spatial dimensions), 'classifier' (final activation name) and
            'weights_path' (if non-empty, weights are loaded from the
            ExData directory via ``model.load_weights``).

    Returns:
        A ``Sequential`` model; it is not compiled here.
    """
    channels = 3
    if config['to_gray'] == 'True':
        channels = 1

    net = Sequential()
    for stage_index, filters in enumerate((32, 64)):
        if stage_index == 0:
            first_conv = Convolution2D(filters, 3, 3, border_mode='same',
                                       input_shape=(channels, config['resize'][1], config['resize'][0]))
        else:
            first_conv = Convolution2D(filters, 3, 3, border_mode='same')
        net.add(first_conv)
        net.add(BatchNormalization())
        net.add(Activation('relu'))
        net.add(Convolution2D(filters, 3, 3))
        net.add(BatchNormalization())
        net.add(Activation('relu'))
        net.add(MaxPooling2D(pool_size=(2, 2)))
        net.add(Dropout(0.25))

    net.add(Flatten())
    net.add(Dense(512))
    net.add(Activation('relu'))
    net.add(BatchNormalization())
    net.add(Dropout(0.5))
    net.add(Dense(10))
    net.add(Activation(config['classifier']))

    relative_path = config['weights_path']
    if relative_path:
        net.load_weights('/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/ExData/' + relative_path)
    return net


def keras_baseline(config):
    """Build a minimal baseline CNN with 8 2x2 filters and a 128-unit head.

    Layers: Conv(valid) -> ReLU -> Conv -> ReLU -> 2x2 MaxPooling ->
    Dropout(0.25) -> Flatten -> Dense(128) -> ReLU -> Dropout(0.5) ->
    Dense(10) -> Activation(config['classifier']).

    Args:
        config: dict with 'to_gray' (string 'True' selects one input channel,
            otherwise three), 'resize' (indexed [1] then [0] for the input's
            spatial dimensions), 'classifier' (final activation name) and
            'weights_path' (if non-empty, weights are loaded from the
            ExData directory via ``model.load_weights``).

    Returns:
        A ``Sequential`` model; it is not compiled here.
    """
    channels = 1 if config['to_gray'] == 'True' else 3
    classes = 10
    filters = 8      # convolutional filters per conv layer
    pool_size = 2    # side of the max-pooling window
    kernel_size = 2  # side of the convolution kernel

    stack = [
        Convolution2D(filters, kernel_size, kernel_size,
                      border_mode='valid',
                      input_shape=(channels, config['resize'][1], config['resize'][0])),
        Activation('relu'),
        Convolution2D(filters, kernel_size, kernel_size),
        Activation('relu'),
        MaxPooling2D(pool_size=(pool_size, pool_size)),
        Dropout(0.25),
        Flatten(),
        Dense(128),
        Activation('relu'),
        Dropout(0.5),
        Dense(classes),
        Activation(config['classifier']),
    ]
    net = Sequential()
    for layer in stack:
        net.add(layer)

    relative_path = config['weights_path']
    if relative_path:
        net.load_weights('/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/ExData/' + relative_path)
    return net


"""
def ST(config):
    img_chnnel = 3
    if config['to_gray'] == 'True':
        img_chnnel = 1
    # initial weights
    b = np.zeros((2, 3), dtype='float32')
    b[0, 0] = 1
    b[1, 1] = 1
    W = np.zeros((50, 6), dtype='float32')
    weights = [W, b.flatten()]
    locnet = Sequential()
    locnet.add(MaxPooling2D(pool_size=(2,2), input_shape=(img_chnnel, config['resize'][1], config['resize'][0])))
    locnet.add(Convolution2D(20, 5, 5))
    locnet.add(MaxPooling2D(pool_size=(2,2)))
    locnet.add(Convolution2D(20, 5, 5))

    locnet.add(Flatten())
    locnet.add(Dense(50))
    locnet.add(Activation('relu'))
    locnet.add(Dense(6, weights=weights))
    #locnet.add(Activation('sigmoid'))
    model = Sequential()
    model.add(SpatialTransformer(localization_net=locnet,
                                 downsample_factor=2, input_shape=(img_chnnel, config['resize'][1], config['resize'][0])))

    model.add(Convolution2D(32, 3, 3, border_mode='same'))
    model.add(Activation('relu'))
    model.add(Convolution2D(32, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.5))
    model.add(Convolution2D(64, 3, 3, border_mode='same'))
    model.add(Activation('relu'))
    model.add(Convolution2D(64, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.5))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    return model
"""


def choose_model(config):
    """Build the model selected by ``config['model_type']``.

    Supported values: 'VGG_16', 'KT' (keras_template), 'KB' (keras_baseline)
    and 'VGG_16_Ori'. Each builder receives only ``config``, so heat-map mode
    is never selected from here.

    Raises:
        KeyError: if ``config`` has no 'model_type' entry.
        ValueError: if 'model_type' is not one of the supported values.
    """
    model_type = config['model_type']
    if model_type == 'VGG_16':
        return VGG_16(config)
    if model_type == 'KT':
        return keras_template(config)
    if model_type == 'KB':
        return keras_baseline(config)
    if model_type == 'VGG_16_Ori':
        return VGG_16_Ori(config)
    # Fail loudly instead of falling off the end and returning None, which
    # would only surface later as an AttributeError on the caller's side.
    raise ValueError('Unknown model_type: {!r}'.format(model_type))
